import asyncio
import time
import numpy as np

from devicepilot.pylog.pylogger import PyLogger
# from devicepilot.py_pli.pylib import VUnits
from devicepilot.export import pli

from config_enum import eef_measurement_unit_enum as meas_enum
from config_enum import od_filter_wheel_enum as od_enum
from config_enum import detector_aperture_slider_enum as das_enum

from hw_abstraction.hal import HAL


# Alternative (disabled) way to obtain the HAL instance:
# hal: HAL = VUnits.instance.hal
# Module-level handles to the HAL and the hardware units used by the procedures below.
# They are evaluated at import time, so importing this module requires pli.hal to be available.
hal: HAL = pli.hal
meas_unit = hal.measurement_unit        # PMT configuration, flash lamp control and measurement execution
das1 = hal.detector_aperture_slider1    # detector aperture slider 1
od1 = hal.od_filter_wheel1              # OD filter wheel 1
od2 = hal.od_filter_wheel2              # OD filter wheel 2


def _hv_scan_range(hv_start, hv_stop, hv_step):
    """Return the high-voltage settings to scan from ``hv_start`` towards ``hv_stop``.

    A small epsilon is added to ``hv_stop`` so that it is itself part of an ascending
    scan despite ``np.arange``'s open upper bound. The values are rounded to 6 decimals
    to hide floating point accumulation error of ``np.arange``.

    Raises:
        ValueError: if ``hv_step`` is zero or the resulting scan contains no points
            (e.g. ``hv_stop`` below ``hv_start`` with a positive step).
    """
    if hv_step == 0:
        raise ValueError("hv_step must be non-zero")
    hv_range = np.arange(hv_start, (hv_stop + 1e-6), hv_step).round(6)
    if hv_range.size == 0:
        raise ValueError(
            f"empty high voltage scan range: hv_start={hv_start!r}, hv_stop={hv_stop!r}, hv_step={hv_step!r}"
        )
    return hv_range


async def pmt_fi_adjustment(pmt_serial_no='', ahrs=101.0, hv_start=0.3, hv_stop=0.73, hv_step=0.005, signal_target=2.2e5):
    """Scan the PMT1 FI high-voltage setting and log the setting closest to a target signal.

    The hardware is started, initialized and homed. PMT1 is configured (``ahrs`` as
    analog high range scale, analog-to-counting conversion disabled) and the flash lamp
    is enabled in low power mode. OD filters 4 and 7 and the 30 aperture are selected.
    For every HV setting in the scan, one FI measurement is executed. The setting whose
    ``_pmt1_analog_total`` has the smallest relative deviation from ``signal_target``
    is logged. The hardware is always shut down afterwards.

    Args:
        pmt_serial_no: Not used by this function.
        ahrs: Value written to ``PMT1.AnalogHighRangeScale``.
        hv_start: First HV setting of the scan.
        hv_stop: Last HV setting of the scan (inclusive for an ascending scan).
        hv_step: Increment between scanned HV settings; must be non-zero.
        signal_target: Desired FI signal; must be positive.

    Returns:
        The string ``'pmt_fi_adjustment done'``. The selected HV setting is only logged.

    Raises:
        ValueError: if ``signal_target`` is not positive or the scan range is empty.
            This is raised before any hardware is started.
    """
    # Validate the arguments up front so bad input fails before the hardware is powered up.
    if not signal_target > 0:  # also rejects NaN
        raise ValueError(f"signal_target must be positive, got {signal_target!r}")
    hv_range = _hv_scan_range(hv_start, hv_stop, hv_step)  # The high voltage setting scan range.

    exc_ms = 100
    low_power = True
    # Transmission (4, 7) = 0.000104713
    od1_filter_nr = 4
    od2_filter_nr = 7
    # Filter positions are 45 units apart, starting at the wheel's configured offset.
    od1_pos = od1.get_config(od_enum.Positions.Offset) + (od1_filter_nr - 1) * 45
    od2_pos = od2.get_config(od_enum.Positions.Offset) + (od2_filter_nr - 1) * 45

    try:
        await hal.startup_hardware()
        await hal.initialize_device()
        await hal.home_movers()

        await meas_unit.set_config(meas_enum.PMT1.HighVoltageSettingFI, hv_start)
        await meas_unit.set_config(meas_enum.PMT1.AnalogHighRangeScale, ahrs)
        await meas_unit.set_config(meas_enum.PMT1.AnalogCountingEquivalent, 1.0)  # Disable analog to counting conversion
        await meas_unit.enable_pmt_hv_fi()
        await meas_unit.enable_flash_lamp_power(low_power)

        meas_unit.clear_measurements()
        guid = "6131ae47-f0c2-4d12-bfa7-b24fa13c4ac1"
        await meas_unit.load_fi_measurement(guid, measurement_time=exc_ms)

        await das1.move_to_named_position(das_enum.Positions.Aperture30)
        await asyncio.gather(od1.move(od1_pos), od2.move(od2_pos))

        fi_signal = np.zeros_like(hv_range)  # The measured FI signal.
        rel_error = np.zeros_like(hv_range)  # The measured relative error.

        PyLogger.logger.info("; hv    ; fi_signal    ; target_error ; ref_signal")

        for i, hv in enumerate(hv_range):
            await meas_unit.set_config(meas_enum.PMT1.HighVoltageSettingFI, hv)
            await meas_unit.enable_pmt_hv_fi()
            await asyncio.sleep(0.2)  # Settling time after changing the high voltage.

            await meas_unit.execute_measurement(guid)
            results = await meas_unit.read_fi_measurement_results(guid, measurement_time=exc_ms, iref0=100000)
            # NOTE(review): reads a private attribute of the results object; confirm there is no public accessor.
            fi_signal[i] = results._pmt1_analog_total
            rel_error[i] = np.abs(fi_signal[i] - signal_target) / signal_target
            PyLogger.logger.info(f"; {hv:5.3f} ; {fi_signal[i]:12.0f} ; {rel_error[i]:12.2%} ; {results.ref_signal}")

        best_hv = hv_range[rel_error.argmin()]
        # NOTE(review): the escaped braces emit a literal "${pmt}"; pmt_serial_no is not substituted.
        # Confirm whether this placeholder is intended for a downstream template.
        PyLogger.logger.info(f"HighVoltageSetting_${{pmt}}_FI = {best_hv:.3f}")

    finally:
        await hal.shutdown_hardware()

    return 'pmt_fi_adjustment done'


async def fi_flash_test(duration=30, excitation=0.1, delay=0.2, xfl_hv=None, low_power=True):
    """Repeatedly execute an FI measurement with both OD filter wheels at filter 1.

    The hardware is started, initialized and homed. Optionally a flash lamp voltage is
    configured, and the flash lamp is enabled in the selected power mode. Measurements
    are then executed back to back (with ``delay`` seconds in between) until more than
    ``duration`` seconds have elapsed since the loop started. The hardware is always
    shut down afterwards.

    Args:
        duration: Elapsed time in seconds after which no new measurement is started.
        excitation: Measurement time in seconds. It is rounded to whole milliseconds
            for ``load_fi_measurement``.
        delay: Pause in seconds after each measurement.
        xfl_hv: Optional flash lamp voltage setting. When given, it is written to
            ``VoltageSettingLowPower`` or ``VoltageSettingHighPower``, matching ``low_power``.
        low_power: Flash lamp power mode passed to ``enable_flash_lamp_power``.
            Defaults to True, which was previously the hard-coded value.

    Returns:
        The string ``'fi_flash_test done'``.
    """
    # Transmission (1, 1) = 1.000000000
    od1_filter_nr = 1
    od2_filter_nr = 1
    # Filter positions are 45 units apart, starting at the wheel's configured offset.
    od1_pos = od1.get_config(od_enum.Positions.Offset) + (od1_filter_nr - 1) * 45
    od2_pos = od2.get_config(od_enum.Positions.Offset) + (od2_filter_nr - 1) * 45

    try:
        await hal.startup_hardware()
        await hal.initialize_device()
        await hal.home_movers()

        if xfl_hv is not None:
            # The voltage belongs to the setting of the power mode that is enabled below.
            if low_power:
                voltage_setting = meas_enum.FlashUnit.VoltageSettingLowPower
            else:
                voltage_setting = meas_enum.FlashUnit.VoltageSettingHighPower
            await meas_unit.set_config(voltage_setting, xfl_hv)
        await meas_unit.enable_flash_lamp_power(low_power)

        meas_unit.clear_measurements()
        guid = "6131ae47-f0c2-4d12-bfa7-b24fa13c4ac1"
        await meas_unit.load_fi_measurement(guid, measurement_time=round(excitation * 1000))

        await das1.move_to_named_position(das_enum.Positions.Aperture30)
        await asyncio.gather(od1.move(od1_pos), od2.move(od2_pos))

        start = time.perf_counter()
        while (time.perf_counter() - start) <= duration:
            await meas_unit.execute_measurement(guid)
            await meas_unit.read_measurement_results(guid)
            await asyncio.sleep(delay)

    finally:
        await hal.shutdown_hardware()

    return 'fi_flash_test done'
